from typing import Any, Dict, List, Optional, Union

from llama_index.core.base.embeddings.base import (
    DEFAULT_EMBED_BATCH_SIZE,
    BaseEmbedding,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr, model_validator
from llama_index.core.callbacks.base import CallbackManager
from llama_index.embeddings.oci_data_science.client import AsyncClient, Client

# from llama_index.embeddings.oci_data_science.utils import _validate_dependency

# Model name used when the caller does not pass `model_name`.
DEFAULT_MODEL = "odsc-embeddings"
# Request timeout, in seconds, used when the caller does not pass `timeout`.
DEFAULT_TIMEOUT = 120
# Maximum number of API retries used when the caller does not pass `max_retries`.
DEFAULT_MAX_RETRIES = 5


class OCIDataScienceEmbedding(BaseEmbedding):
    """
    Embedding class for OCI Data Science models.

    This class provides methods to generate embeddings using models deployed on
    Oracle Cloud Infrastructure (OCI) Data Science. It supports both synchronous
    and asynchronous requests and handles authentication, batching, and other
    configurations.

    Setup:
        Install the required packages:
        ```bash
        pip install -U oracle-ads llama-index-embeddings-oci-data-science
        ```

        Configure authentication using `ads.set_auth()`. For example, to use OCI
        Resource Principal for authentication:
        ```python
        import ads
        ads.set_auth("resource_principal")
        ```

        For more details on authentication, see:
        https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html

        Ensure you have the required policies to access the OCI Data Science Model
        Deployment endpoint:
        https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-policies-auth.htm

        To learn more about deploying LLM models in OCI Data Science, see:
        https://docs.oracle.com/en-us/iaas/data-science/using/ai-quick-actions-model-deploy.htm

    Examples:
        Basic Usage:
        ```python
        import ads
        from llama_index.embeddings.oci_data_science import OCIDataScienceEmbedding

        ads.set_auth(auth="security_token", profile="OC1")

        embeddings = OCIDataScienceEmbedding(
            endpoint="https://<MD_OCID>/predict",
        )

        e1 = embeddings.get_text_embedding("This is a test document")
        print(e1)

        e2 = embeddings.get_text_embedding_batch([
            "This is a test document",
            "This is another test document"
        ])
        print(e2)
        ```

        Asynchronous Usage:
        ```python
        import ads
        import asyncio
        from llama_index.embeddings.oci_data_science import OCIDataScienceEmbedding

        ads.set_auth(auth="security_token", profile="OC1")

        embeddings = OCIDataScienceEmbedding(
            endpoint="https://<MD_OCID>/predict",
        )

        async def async_embedding():
            e1 = await embeddings.aget_query_embedding("This is a test document")
            print(e1)

        asyncio.run(async_embedding())
        ```

    Attributes:
        endpoint (str): The URI of the endpoint from the deployed model.
        auth (Dict[str, Any]): The authentication dictionary used for OCI API requests.
        model_name (str): The name of the OCI Data Science embedding model.
        embed_batch_size (int): The batch size for embedding calls.
        max_retries (int): The maximum number of API retries.
        timeout (float): The timeout to use in seconds.
        additional_kwargs (Dict[str, Any]): Additional keyword arguments for the OCI Data Science AI request.
        default_headers (Dict[str, str]): The default headers for API requests.

    """

    endpoint: str = Field(
        default=None, description="The URI of the endpoint from the deployed model."
    )

    auth: Union[Dict[str, Any], None] = Field(
        default_factory=dict,
        exclude=True,
        description=(
            "The authentication dictionary used for OCI API requests. "
            "If not provided, it will be autogenerated based on environment variables."
        ),
    )
    model_name: Optional[str] = Field(
        default=DEFAULT_MODEL,
        description="The name of the OCI Data Science embedding model to use.",
    )

    embed_batch_size: int = Field(
        default=DEFAULT_EMBED_BATCH_SIZE,
        description="The batch size for embedding calls.",
        gt=0,
        le=2048,
    )

    max_retries: int = Field(
        default=DEFAULT_MAX_RETRIES,
        description="The maximum number of API retries.",
        ge=0,
    )

    timeout: float = Field(
        default=DEFAULT_TIMEOUT, description="The timeout to use in seconds.", ge=0
    )

    additional_kwargs: Optional[Dict[str, Any]] = Field(
        default_factory=dict,
        description="Additional keyword arguments for the OCI Data Science AI request.",
    )

    default_headers: Optional[Dict[str, str]] = Field(
        default_factory=dict, description="The default headers for API requests."
    )

    # Lazily created clients. They default to None so the `client` and
    # `async_client` properties can tell "not created yet" apart from a live
    # client without relying on the attribute being absent.
    _client: Optional[Client] = PrivateAttr(default=None)
    _async_client: Optional[AsyncClient] = PrivateAttr(default=None)

    def __init__(
        self,
        endpoint: str,
        model_name: Optional[str] = DEFAULT_MODEL,
        auth: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = DEFAULT_TIMEOUT,
        max_retries: Optional[int] = DEFAULT_MAX_RETRIES,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        default_headers: Optional[Dict[str, str]] = None,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the OCIDataScienceEmbedding instance.

        Args:
            endpoint (str): The URI of the endpoint from the deployed model.
            model_name (Optional[str]): The name of the OCI Data Science embedding model to use. Defaults to "odsc-embeddings".
            auth (Optional[Dict[str, Any]]): The authentication dictionary for OCI API requests. Defaults to None.
            timeout (Optional[float]): The timeout setting for the HTTP request in seconds. Defaults to 120; ``None`` falls back to that default.
            max_retries (Optional[int]): The maximum number of retry attempts for the request. Defaults to 5; ``None`` falls back to that default.
            embed_batch_size (int): The batch size for embedding calls. Defaults to DEFAULT_EMBED_BATCH_SIZE.
            additional_kwargs (Optional[Dict[str, Any]]): Additional arguments for the OCI Data Science AI request. Defaults to None.
            default_headers (Optional[Dict[str, str]]): The default headers for API requests. Defaults to None.
            callback_manager (Optional[CallbackManager]): A callback manager for handling events during embedding operations. Defaults to None.
            **kwargs: Additional keyword arguments.

        """
        # The signature allows None for `timeout` and `max_retries`, but the
        # fields are non-optional `float`/`int`; map None to the documented
        # defaults instead of letting field validation reject it.
        effective_timeout = DEFAULT_TIMEOUT if timeout is None else timeout
        effective_retries = DEFAULT_MAX_RETRIES if max_retries is None else max_retries

        super().__init__(
            model_name=model_name,
            endpoint=endpoint,
            auth=auth,
            embed_batch_size=embed_batch_size,
            timeout=effective_timeout,
            max_retries=effective_retries,
            additional_kwargs=additional_kwargs or {},
            default_headers=default_headers or {},
            callback_manager=callback_manager,
            **kwargs,
        )

    @model_validator(mode="before")
    @classmethod
    def validate_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """
        Pre-validation hook for the raw values passed to the model.

        NOTE: the dependency check (`_validate_dependency`) is currently
        disabled, so this validator performs no checks and returns the input
        unchanged.

        Args:
            values (Dict[str, Any]): The values passed to the model.

        Returns:
            Dict[str, Any]: The same values, unmodified.

        """
        return values

    @property
    def client(self) -> Client:
        """
        Return the synchronous client instance, creating it on first access.

        The client is built from this instance's ``endpoint``, ``auth``,
        ``max_retries`` and ``timeout`` and then reused on later accesses.

        Returns:
            Client: The synchronous client for interacting with the OCI Data Science Model Deployment endpoint.

        """
        # getattr with a default also covers instances whose private
        # attributes were never initialized.
        if getattr(self, "_client", None) is None:
            self._client = Client(
                endpoint=self.endpoint,
                auth=self.auth,
                retries=self.max_retries,
                timeout=self.timeout,
            )
        return self._client

    @property
    def async_client(self) -> AsyncClient:
        """
        Return the asynchronous client instance, creating it on first access.

        The client is built from this instance's ``endpoint``, ``auth``,
        ``max_retries`` and ``timeout`` and then reused on later accesses.

        Returns:
            AsyncClient: The asynchronous client for interacting with the OCI Data Science Model Deployment endpoint.

        """
        # getattr with a default also covers instances whose private
        # attributes were never initialized.
        if getattr(self, "_async_client", None) is None:
            self._async_client = AsyncClient(
                endpoint=self.endpoint,
                auth=self.auth,
                retries=self.max_retries,
                timeout=self.timeout,
            )
        return self._async_client

    @classmethod
    def class_name(cls) -> str:
        """
        Get the class name.

        Returns:
            str: The name of the class.

        """
        return "OCIDataScienceEmbedding"

    def _fetch_embedding_data(
        self, inputs: Union[str, List[str]]
    ) -> List[Dict[str, Any]]:
        """
        Send a synchronous embeddings request and return the response's "data" entries.

        The request carries ``additional_kwargs`` as the payload and
        ``default_headers`` as the headers.

        Args:
            inputs (Union[str, List[str]]): A single text or a list of texts to embed.

        Returns:
            List[Dict[str, Any]]: The ``"data"`` entries of the response; callers
                read the ``"embedding"`` key of each entry.

        """
        response = self.client.embeddings(
            input=inputs,
            payload=self.additional_kwargs,
            headers=self.default_headers,
        )
        return response["data"]

    async def _afetch_embedding_data(
        self, inputs: Union[str, List[str]]
    ) -> List[Dict[str, Any]]:
        """
        Send an asynchronous embeddings request and return the response's "data" entries.

        The request carries ``additional_kwargs`` as the payload and
        ``default_headers`` as the headers.

        Args:
            inputs (Union[str, List[str]]): A single text or a list of texts to embed.

        Returns:
            List[Dict[str, Any]]: The ``"data"`` entries of the response; callers
                read the ``"embedding"`` key of each entry.

        """
        response = await self.async_client.embeddings(
            input=inputs,
            payload=self.additional_kwargs,
            headers=self.default_headers,
        )
        return response["data"]

    def _get_query_embedding(self, query: str) -> List[float]:
        """
        Generate an embedding for a query string.

        Args:
            query (str): The query string for which to generate an embedding.

        Returns:
            List[float]: The embedding vector for the query.

        """
        return self._fetch_embedding_data(query)[0]["embedding"]

    def _get_text_embedding(self, text: str) -> List[float]:
        """
        Generate an embedding for a text string.

        Args:
            text (str): The text string for which to generate an embedding.

        Returns:
            List[float]: The embedding vector for the text.

        """
        return self._fetch_embedding_data(text)[0]["embedding"]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """
        Asynchronously generate an embedding for a text string.

        Args:
            text (str): The text string for which to generate an embedding.

        Returns:
            List[float]: The embedding vector for the text.

        """
        data = await self._afetch_embedding_data(text)
        return data[0]["embedding"]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of text strings.

        All texts are sent in a single request.

        Args:
            texts (List[str]): A list of text strings for which to generate embeddings.

        Returns:
            List[List[float]]: A list of embedding vectors corresponding to the input texts.

        """
        return [raw["embedding"] for raw in self._fetch_embedding_data(texts)]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """
        Asynchronously generate an embedding for a query string.

        Args:
            query (str): The query string for which to generate an embedding.

        Returns:
            List[float]: The embedding vector for the query.

        """
        data = await self._afetch_embedding_data(query)
        return data[0]["embedding"]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """
        Asynchronously generate embeddings for a list of text strings.

        All texts are sent in a single request.

        Args:
            texts (List[str]): A list of text strings for which to generate embeddings.

        Returns:
            List[List[float]]: A list of embedding vectors corresponding to the input texts.

        """
        data = await self._afetch_embedding_data(texts)
        return [raw["embedding"] for raw in data]
